import os
import subprocess
import re
import sys, os
# import matplotlib
# import matplotlib.pyplot as plt
# import matplotlib.colors as colors
# import matplotlib.cm as cmx
# from matplotlib.colors import LinearSegmentedColormap
#
# matplotlib.rcParams['mathtext.fontset'] = 'stix'
# matplotlib.rcParams['font.family'] = 'STIXGeneral'
# matplotlib.rcParams['font.size'] = 10

import numpy as np
from subprocess import Popen, PIPE
from math import pi
norm = np.linalg.norm

from xyz_src.read import read_trj, read_intc, read_energy

def run_umb(iiter, xyzflat, vxyzflat=None, temp=100, step=100000, center=0, k=1, model="initial_model", stride=100, workdir="umb", damp=50):
    """Run one restrained LAMMPS+PLUMED MD segment and archive its results.

    Writes ``{workdir}/{iiter}.dat`` (write_data) plus the LAMMPS and PLUMED
    inputs (write_inscript with ``umb=True``), then runs
    ``mpirun -np 1 $LMP_EXEC -in {workdir}/{iiter}.in``; ``LMP_EXEC`` defaults
    to ``lmp_gcc``. LAMMPS stdout goes to ``{workdir}/{iiter}.screen``.
    Afterwards it reads back the trajectory, PLUMED output and energies, checks
    that their frame counts agree, deletes the intermediate files and saves
    everything to ``{workdir}/{iiter}.npz``.

    Parameters
    ----------
    iiter : run identifier, used as the file stem inside ``workdir``.
    xyzflat : flat coordinate array laid out x0, y0, z0, x1, ...
    vxyzflat : optional flat velocity array in the same layout.
    temp, step, center, k, model, stride, damp : forwarded to write_inscript.
        These are the Langevin temperature, number of MD steps, restraint
        centre and stiffness on the ANN output ``a.0``, model path, output
        stride and Langevin damping.
    workdir : directory holding all input/output files (must already exist).

    Returns
    -------
    basin, ends : result of find_basin on ``colvar``.
    xi : per-frame ``a.0`` column of the PLUMED output.
    maxe : maximum of column 1 of the energy table. This is presumably the
        potential energy, given the "$T $e $k $t" print order in md_script
        (TODO confirm against read_energy).

    Raises
    ------
    NameError
        If LAMMPS exits with a non-zero status, or if the energy row count
        disagrees with the PLUMED or trajectory frame count. On failure the
        intermediate files are left in place for inspection.
    """
    name = f'{workdir}/{iiter}'
    write_data(name, xyzflat, vxyzflat)
    write_inscript(workdir, iiter, temp, step, center, k, model, stride, damp, umb=True)
    lmp = os.environ.get("LMP_EXEC", "lmp_gcc")
    command = f"mpirun -np 1 {lmp} -in {workdir}/{iiter}.in"
    print("command: ", command)
    # Plain str.split() keeps support for LMP_EXEC values that carry extra flags.
    with open(f"{workdir}/{iiter}.screen", "w+") as fout:
        returncode = subprocess.call(command.split(), stdout=fout)
    if returncode != 0:
        # Fail loudly: a crashed run can leave partial output that would otherwise
        # be parsed as if it were complete. NameError matches this file's existing
        # failure convention, so callers that already catch it keep working.
        raise NameError(f"LAMMPS run failed with exit code {returncode}: "
                        f"{command} (see {workdir}/{iiter}.screen)")

    newxyz = read_trj(name)
    allplumed = read_intc(name)
    # Column layout follows plumed_print: c.bias, a.0, t1..t6, d1..d9, a1..a7.
    # This assumes read_intc drops PLUMED's leading time column (TODO confirm).
    bias = allplumed[:, 0]
    xi = allplumed[:, 1]
    intc = allplumed[:, 2:]
    # Presumably the t2/t3 torsions; these slices are views into allplumed.
    colvar = intc[:, 1:3]
    # Prepend the first thermo row from the LAMMPS log. The fix-print energy
    # file presumably lacks that first frame (TODO confirm).
    ie = find_firste(name)
    alle = read_energy(name)
    alle = np.vstack([ie.reshape([1, -1]), alle])
    if (alle.shape[0]!=intc.shape[0]):
        raise NameError(f"alle.shape[0]{alle.shape[0]} != intc.shape[0]{intc.shape[0]}")
    if (alle.shape[0]!=newxyz.shape[0]):
        raise NameError(f"alle.shape[0]{alle.shape[0]} != newxyz.shape[0]{newxyz.shape[0]}")
    for suffix in ["-plumed.out", ".intc", ".energy", ".lammpstrj", "-plumed.dat", ".in", ".dat", ".log", ".screen"]:
        os.remove(f"{workdir}/{iiter}{suffix}")

    # NOTE: find_basin wraps colvar in place. Because colvar is a view, the
    # saved colvar and columns 1-2 of intc hold the wrapped values.
    basin, ends = find_basin(colvar)

    np.savez(f"{workdir}/{iiter}.npz", xyz=newxyz, colvar=colvar,
            intc=intc,
            xi=xi, bias=bias,
            center=center, k=k, temp=temp,
            alle=alle, basin=basin, ends=ends)
    maxe = np.max(alle[:, 1])

    return basin, ends, xi, maxe

def find_firste(iiter):
    """Parse the thermo row that follows the last "Step" line of ``{iiter}.log``.

    ``iiter`` is the path stem of the LAMMPS log (e.g. ``"umb/3"``). Every
    line that directly follows a line containing ``"Step"`` is split on
    whitespace, and fields 1..4 (the step number is skipped) are converted to
    float. The last such row wins. Under md_script's
    ``thermo_style custom step temp pe ke etotal press`` these are presumably
    temp, pe, ke and etotal.

    Returns
    -------
    numpy.ndarray
        1-D float array of the parsed fields, or an empty array when no
        "Step" line is followed by another line.
    """
    values = []
    previous = ""
    with open(f"{iiter}.log") as log:
        for line in log:
            # A row is a thermo row exactly when the line before it was a header.
            if "Step" in previous:
                values = np.array([float(field) for field in line.split()[1:5]])
            previous = line
    return np.array(values)


def find_basin(intc):
    """Return the first basin entered by a 2-D torsion trajectory.

    ``intc`` is an (n_frames, 2) float array of torsion pairs, presumably in
    radians given the 2*pi wrapping. It is MODIFIED IN PLACE: column-1 values
    below -2.5 get +2*pi and column-0 values above 3 get -2*pi, presumably so
    the basins below do not straddle the periodic boundary. When ``intc`` is a
    view, the parent array sees the wrapped values too.

    Each basin is the ellipse inscribed in an axis-aligned box. A frame lies in
    basin b when its box-normalised distance to the box centre is below 1.
    Frames are scanned from index 1, so frame 0 is never tested. Within a frame,
    basins are tried in the order 0 (C5), 1 (C7eq), 2 (Cax).

    Returns
    -------
    (basin, ends)
        Basin index and the first frame index found inside it, or ``(-1, -1)``
        if no frame after the first enters any basin.
    """
    # Box corners [[lo_0, lo_1], [hi_0, hi_1]] per basin, in intc column order.
    boxes = np.array([
        [[-3.05, 2.15], [-2.04, 3.43]],   # 0: C5
        [[-1.84, 0.11], [-1.00, 1.82]],   # 1: C7eq
        [[0.7, -1.55], [1.35, -0.02]],    # 2: Cax
    ])
    centers = (boxes[:, 0, :] + boxes[:, 1, :]) / 2.
    half_widths = (boxes[:, 1, :] - boxes[:, 0, :]) / 2.

    # Periodic shift, applied in place to the caller's array.
    below = intc[:, 1] < -2.5
    intc[below, 1] = intc[below, 1] + 2*pi
    above = intc[:, 0] > 3
    intc[above, 0] = intc[above, 0] - 2*pi

    dists = [norm((intc - c) / r, axis=1) for c, r in zip(centers, half_widths)]
    for frame in range(1, intc.shape[0]):
        for basin, dist in enumerate(dists):
            if dist[frame] < 1:
                return basin, frame
    return -1, -1


def write_data(name, xyzflat, vxyzflat0=None):
    """Write the LAMMPS data file ``{name}.dat`` for the 22-atom system.

    The file is laid out in this order:

    1. ``before_xyz``: header, counts and masses.
    2. An ``Atoms`` section with ``atom-id molecule-id type charge x y z``
       rows; the molecule id is always 1.
    3. An optional ``Velocities`` section.
    4. ``after_xyz``: topology and force-field coefficients.

    Parameters
    ----------
    name : output path stem; ``.dat`` is appended.
    xyzflat : flat array with ``.shape``, laid out x0, y0, z0, x1, ...
        ``shape[0] // 3`` atoms are written.
    vxyzflat0 : optional flat velocity array in the same layout; the section
        is omitted when None.
    """
    # Per-atom LAMMPS type ids and partial charges, in atom order.
    atom_types = [1, 2, 1, 1, 3, 4, 5, 6, 2, 7, 2, 1, 1, 1, 3, 4, 5, 6,
            2, 7, 7, 7]
    charges = [0.1123, -0.3662, 0.1123, 0.1123, 0.5972, -0.5679, -0.4157,
            0.2719, 0.0337, 0.0823, -0.1825, 0.0603, 0.0603, 0.0603,
            0.5973, -0.5679, -0.4157, 0.2719, -0.149, 0.0976, 0.0976,
            0.0976]
    n_atoms = xyzflat.shape[0] // 3
    with open(f"{name}.dat", "w+") as out:
        print(before_xyz, file=out)
        print("Atoms\n", file=out)
        for idx in range(n_atoms):
            base = 3 * idx
            print(idx + 1, "1", atom_types[idx], charges[idx],
                  xyzflat[base], xyzflat[base + 1], xyzflat[base + 2],
                  file=out)
        print("", file=out)
        if vxyzflat0 is not None:
            print("Velocities\n", file=out)
            for idx in range(n_atoms):
                base = 3 * idx
                print(idx + 1, vxyzflat0[base], vxyzflat0[base + 1],
                      vxyzflat0[base + 2], file=out)
            print("", file=out)
        print(after_xyz, file=out)

def write_inscript(workdir, iiter=0, temp=300, step=2000, center=0, k=1, model="initial_model", stride=100, damp=100, umb=True):
    """Write the LAMMPS and PLUMED input files for one run.

    Produces ``{workdir}/{iiter}.in`` from ``md_script``. It also produces
    ``{workdir}/{iiter}-plumed.dat`` from ``plumed_header``, then the
    ``plumed_umb`` restraint, then ``plumed_print``. All templates are filled
    with the same keyword fields.

    Parameters
    ----------
    workdir, iiter : directory and file stem for every input/output path.
    temp, step, stride, damp : Langevin temperature, number of MD steps,
        output stride and Langevin damping used by the LAMMPS template.
    center, k : restraint centre (AT) and stiffness (KAPPA) on the ANN output.
    model : model path substituted into the PLUMED ANN action.
    umb : when True, the restraint uses stiffness ``k``. When False, the same
        restraint is written with KAPPA=0. ``plumed_print`` always prints
        ``c.bias``, and label ``c`` is only defined by ``plumed_umb``, so
        omitting it would make PLUMED reject the input. Writing KAPPA=0 keeps
        the ``.intc`` column layout identical with a bias of 0.
    """
    fields = dict(workdir=workdir, iiter=iiter, temp=temp, step=step,
                  center=center, k=k, model=model, stride=stride, damp=damp)
    script = md_script.format(**fields)
    plumed_script = plumed_header.format(**fields)
    if umb:
        plumed_script += plumed_umb.format(**fields)
    else:
        # Zero-stiffness restraint: defines label c (needed by PRINT) with no force.
        plumed_script += plumed_umb.format(**{**fields, "k": 0})
    plumed_script += plumed_print.format(**fields)
    with open(f"{workdir}/{iiter}.in", "w+") as fout:
        print(script, file=fout)
    with open(f"{workdir}/{iiter}-plumed.dat", "w+") as fout:
        print(plumed_script, file=fout)

# LAMMPS input template for one MD segment. write_inscript fills it with
# str.format; the fields used here are workdir, iiter, temp, step, stride and damp.
# - `units real` + `timestep 0.5` gives 0.5 fs steps.
#   NOTE(review): the trailing "#1 ns" on the `run` line only holds for
#   step=2,000,000. The defaults (2000 in write_inscript, 100000 in run_umb)
#   are much shorter.
# - The velocity and langevin seeds are built by appending iiter to fixed
#   prefixes (4928..., 3214...). TODO confirm iiter stays small enough that
#   LAMMPS accepts the resulting integer seeds.
# - find_firste parses the thermo columns (step temp pe ke etotal press).
# - The fix-print energy file holds "$T $e $k $t" (temp, pe, ke, etotal).
# - PLUMED is coupled through `fix plumed` reading {workdir}/{iiter}-plumed.dat.
md_script="""log {workdir}/{iiter}.log
units		real

neigh_modify    once yes one 22 page 2200

atom_style	full
bond_style      harmonic
angle_style     harmonic
dihedral_style  harmonic
pair_style      lj/cut/coul/cut 10.0
pair_modify     mix arithmetic
special_bonds   amber
kspace_style    none
read_data       {workdir}/{iiter}.dat

thermo          {stride}
thermo_style custom step temp pe ke etotal press

velocity all create {temp} 4928{iiter} rot yes mom yes dist gaussian
fix mom all momentum 100 linear 1 1 1 angular

dump            1 all custom {stride} {workdir}/{iiter}.lammpstrj id x y z # vx vy vz
fix             1 all nve
fix             2 all langevin {temp} {temp} {damp} 3214{iiter}
variable e     equal pe
variable k     equal ke
variable t     equal etotal
variable T     equal temp
fix extra all print {stride} "$T $e $k $t" file {workdir}/{iiter}.energy
fix        plumed all plumed plumedfile {workdir}/{iiter}-plumed.dat outfile {workdir}/{iiter}-plumed.out
timestep        0.5
run             {step}   #1 ns
"""

# PLUMED input header, filled by str.format in write_inscript; only {model} is
# used here.
# - p1..p22 are the positions of all 22 atoms. Their x/y/z components are fed,
#   in order, to the `a: ANN` action, whose output a.0 is restrained and
#   printed. The MODELPATH/INPUT/OUTPUT/GRAD keywords suggest a custom ANN
#   action rather than a stock PLUMED one; TODO confirm.
# - t1..t6, d1..d9 and a1..a7 are torsions, distances and angles that are
#   printed alongside a.0.
# - run_umb passes the t2/t3 columns (presumably) to find_basin.
plumed_header = """UNITS LENGTH=A TIME=fs
p1: POSITION ATOM=1
p2: POSITION ATOM=2
p3: POSITION ATOM=3
p4: POSITION ATOM=4
p5: POSITION ATOM=5
p6: POSITION ATOM=6
p7: POSITION ATOM=7
p8: POSITION ATOM=8
p9: POSITION ATOM=9
p10: POSITION ATOM=10
p11: POSITION ATOM=11
p12: POSITION ATOM=12
p13: POSITION ATOM=13
p14: POSITION ATOM=14
p15: POSITION ATOM=15
p16: POSITION ATOM=16
p17: POSITION ATOM=17
p18: POSITION ATOM=18
p19: POSITION ATOM=19
p20: POSITION ATOM=20
p21: POSITION ATOM=21
p22: POSITION ATOM=22
t1: TORSION ATOMS=2,5,7,9 NOPBC
t2: TORSION ATOMS=5,7,9,15 NOPBC
t3: TORSION ATOMS=7,9,15,17 NOPBC
t4: TORSION ATOMS=9,15,17,19 NOPBC
t5: TORSION ATOMS=17,15,9,11 NOPBC
t6: TORSION ATOMS=11,9,7,5 NOPBC
d1: DISTANCE ATOMS=19,17 NOPBC
d2: DISTANCE ATOMS=17,15 NOPBC
d3: DISTANCE ATOMS=15,9 NOPBC
d4: DISTANCE ATOMS=9,11 NOPBC
d5: DISTANCE ATOMS=9,7 NOPBC
d6: DISTANCE ATOMS=7,5 NOPBC
d7: DISTANCE ATOMS=5,2 NOPBC
d8: DISTANCE ATOMS=14,16 NOPBC
d9: DISTANCE ATOMS=5,6 NOPBC
a1: ANGLE ATOMS=9,17,15 NOPBC
a2: ANGLE ATOMS=17,15,9 NOPBC
a3: ANGLE ATOMS=15,9,7 NOPBC
a4: ANGLE ATOMS=9,7,5 NOPBC
a5: ANGLE ATOMS=7,5,2 NOPBC
a6: ANGLE ATOMS=15,9,11 NOPBC
a7: ANGLE ATOMS=11,9,7 NOPBC
a: ANN ARG=p1.x,p1.y,p1.z,p2.x,p2.y,p2.z,p3.x,p3.y,p3.z,p4.x,p4.y,p4.z,p5.x,p5.y,p5.z,p6.x,p6.y,p6.z,p7.x,p7.y,p7.z,p8.x,p8.y,p8.z,p9.x,p9.y,p9.z,p10.x,p10.y,p10.z,p11.x,p11.y,p11.z,p12.x,p12.y,p12.z,p13.x,p13.y,p13.z,p14.x,p14.y,p14.z,p15.x,p15.y,p15.z,p16.x,p16.y,p16.z,p17.x,p17.y,p17.z,p18.x,p18.y,p18.z,p19.x,p19.y,p19.z,p20.x,p20.y,p20.z,p21.x,p21.y,p21.z,p22.x,p22.y,p22.z MODELPATH={model} INPUT=X OUTPUT=z_x GRAD=dz_dx
"""
# PLUMED RESTRAINT (label c) on the ANN output a.0, centred at {center} with
# stiffness KAPPA={k}. write_inscript appends it after plumed_header.
plumed_umb="""RESTRAINT ...
   LABEL=c
   ARG=a.0 AT={center} KAPPA={k}
... RESTRAINT
"""
# PRINT the restraint bias and all CVs every {stride} steps to
# {workdir}/{iiter}.intc, in the column order run_umb indexes:
# c.bias, a.0, t1..t6, d1..d9, a1..a7.
# NOTE: c.bias requires label c, which only plumed_umb defines.
plumed_print="""PRINT ARG=c.bias,a.0,t1,t2,t3,t4,t5,t6,d1,d2,d3,d4,d5,d6,d7,d8,d9,a1,a2,a3,a4,a5,a6,a7 STRIDE={stride} FILE={workdir}/{iiter}.intc
FLUSH STRIDE={stride}
"""

# Head of the LAMMPS data file written by write_data. It holds the counts
# (22 atoms, 21 bonds, 36 angles, 66 dihedrals), the type counts, a fixed
# -100..100 box (Angstrom under `units real`) and the masses of the 7 atom
# types. write_data writes the Atoms/Velocities sections after it.
before_xyz = """ LAMMPS data file for ACE

22 atoms
21 bonds
36 angles
66 dihedrals

7 atom types
8 bond types
16 angle types
19 dihedral types

-100 100  xlo xhi
-100 100  ylo yhi
-100 100  zlo zhi

Masses

1 1.008
2 12.01
3 12.01
4 16.0
5 14.01
6 1.008
7 1.008

"""
# Tail of the LAMMPS data file written by write_data. It contains the fixed
# topology (Bonds, Angles, Dihedrals) and the force-field coefficients for
# the 22-atom system. The coefficient sections match the harmonic
# bond/angle/dihedral styles and the lj/cut/coul/cut pair style set in
# md_script. The entry counts must agree with the header in before_xyz.
after_xyz = """
Bonds

 1 3 2   3
 2 3 2   4
 3 3 1   2
 4 3 11 12
 5 3 11 13
 6 3 11 14
 7 5 9  10
 8 7 7   8
 9 5 19 20
10 5 19 21
11 5 19 22
12 7 17 18
13 1 5   6
14 2 5   7
15 4 2   5
16 1 15 16
17 2 15 17
18 6 9  11
19 4 9  15
20 8 7   9
21 8 17 19

Angles

1 2 5 7 8
2 4 4 2 5
3 5 3 2 4
4 4 3 2 5
5 5 1 2 3
6 5 1 2 4
7 4 1 2 5
8 2 15 17 18
9 5 13 11 14
10 5 12 11 13
11 5 12 11 14
12 9 10 9 11
13 10 10 9 15
14 11 9 11 12
15 11 9 11 13
16 11 9 11 14
17 12 8 7 9
18 13 7 9 10
19 16 21 19 22
20 16 20 19 21
21 16 20 19 22
22 12 18 17 19
23 13 17 19 20
24 13 17 19 21
25 13 17 19 22
26 1 6 5 7
27 3 5 7 9
28 6 2 5 6
29 7 2 5 7
30 1 16 15 17
31 3 15 17 19
32 8 11 9 15
33 6 9 15 16
34 7 9 15 17
35 14 7 9 11
36 15 7 9 15

Dihedrals

1 1 6 5 7 8
2 2 6 5 7 8
3 3 5 7 9 10
4 10 4 2 5 6
5 3 4 2 5 6
6 11 4 2 5 6
7 3 4 2 5 7
8 10 3 2 5 6
9 3 3 2 5 6
10 11 3 2 5 6
11 3 3 2 5 7
12 2 2 5 7 8
13 10 1 2 5 6
14 3 1 2 5 6
15 11 1 2 5 6
16 3 1 2 5 7
17 1 16 15 17 18
18 2 16 15 17 18
19 3 15 17 19 20
20 3 15 17 19 21
21 3 15 17 19 22
22 12 14 11 9 15
23 12 13 11 9 15
24 12 12 11 9 15
25 12 10 9 11 12
26 12 10 9 11 13
27 12 10 9 11 14
28 10 10 9 15 16
29 11 10 9 15 16
30 3 10 9 15 17
31 2 9 15 17 18
32 3 8 7 9 10
33 3 8 7 9 11
34 3 8 7 9 15
35 12 7 9 11 12
36 12 7 9 11 13
37 12 7 9 11 14
38 3 18 17 19 20
39 3 18 17 19 21
40 3 18 17 19 22
41 19 5 9 7 8
42 19 15 19 17 18
43 2 6 5 7 9
44 1 5 7 9 11
45 4 5 7 9 11
46 5 5 7 9 11
47 6 5 7 9 11
48 7 5 7 9 15
49 8 5 7 9 15
50 9 5 7 9 15
51 6 5 7 9 15
52 2 2 5 7 9
53 2 16 15 17 19
54 3 11 9 15 16
55 13 11 9 15 17
56 14 11 9 15 17
57 5 11 9 15 17
58 6 11 9 15 17
59 2 9 15 17 19
60 3 7 9 15 16
61 15 7 9 15 17
62 16 7 9 15 17
63 17 7 9 15 17
64 6 7 9 15 17
65 18 2 7 5 6
66 18 9 17 15 16

Pair Coeffs

1 0.01570000002623629  2.6495327872602221
2 0.10939999991572773  3.3996695084507409
3 0.086000000128358844 3.3996695079448309
4 0.20999999984182244  2.9599219016446874
5 0.16999999991766696  3.2499985240310356
6 0.015700000004219245 1.0690784617205229
7 0.015700000098461422 2.4713530426421655

Bond Coeffs

1 570.0 1.229
2 490.0 1.335
3 340.0 1.090
4 317.0 1.522
5 340.0 1.090
6 310.0 1.526
7 434.0 1.010
8 337.0 1.449

Angle Coeffs

1 80.0 122.90005267195104
2 50.0 120.00005142908158
3 50.0 121.90005224337536
4 50.0 109.50004692903693
5 35.0 109.50004692903693
6 80.0 120.40005160051184
7 70.0 116.60004997192425
8 63.0 111.10004761475803
9 50.0 109.50004692903693
10 50.0 109.50004692903693
11 50.0 109.50004692903693
12 50.0 118.04005047448166
13 50.0 109.50004692903693
14 80.0 109.70004701475206
15 63.0 110.10004718618234
16 35.0 109.50004692903693

Dihedral Coeffs

 1 2.0          1 1
 2 2.5         -1 2
 3 0.0          1 2
 4 2.0          1 2
 5 0.4          1 3
 6 0.0          1 4
 7 0.0          1 1
 8 0.272        1 2
 9 0.43         1 3
10 0.8          1 1
11 0.08        -1 3
12 0.155555556  1 3
13 0.20         1 1
14 0.20         1 2
15 0.45        -1 1
16 1.58        -1 2
17 0.55        -1 3
18 10.5        -1 2
19 1.10        -1 2
"""
